# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate data for mcore attention."""

import numpy as np

# NOTE(review): presumably the total number of cache blocks available to the
# test; it is not referenced in the visible part of this file - confirm usage.
NUM_BLOCKS = 128
# Number of token slots per cache block; used to turn (block, offset) pairs
# into flat slot indices.
BLOCK_SIZE = 16
# Number of sequences in the generated batch.
BATCH_SIZE = 2
# Number of tokens per sequence placed in the prefill hidden states.
PREFILL_SEQ_LEN = 2
# Number of tokens per sequence accounted for by the decode step.
DECODE_SEQ_LEN = 1
# Full hidden width; the per-head width is HIDDEN_SIZE // NUM_HEADS.
HIDDEN_SIZE = 64
# Number of heads the hidden states are split into.
NUM_HEADS = 2
# Sequence length of the randomly generated hidden states (prefill + decode).
MAX_SEQ_LEN = PREFILL_SEQ_LEN + DECODE_SEQ_LEN
# Seed NumPy's global generator at import time so the random data produced by
# this module is reproducible.
np.random.seed(2025)


def get_init_params(n_kv_heads: int) -> dict:
    """Generate weights, hidden states and paged-cache metadata for the test.

    All random values come from NumPy's global generator, so the result
    depends on the generator state at call time. The order of the three
    ``np.random.normal`` draws (hidden states, QKV weight, projection weight)
    is part of the contract: reordering them changes every generated value.

    Args:
        n_kv_heads: Number of key/value heads. It sizes the key and value
            parts of the fused QKV weight.

    Returns:
        A dict with the following entries:

        * ``"qkv_weight"``: float32 array of shape
          ``((NUM_HEADS + 2 * n_kv_heads) * head_dim, HIDDEN_SIZE)``.
        * ``"proj_weight"``: float32 array of shape
          ``(HIDDEN_SIZE, NUM_HEADS * head_dim)``.
        * ``"prefill_hidden_states"``: float64 array of shape
          ``(BATCH_SIZE, PREFILL_SEQ_LEN, NUM_HEADS, head_dim)`` holding all
          but the last token of every sequence.
        * ``"decoder_hidden_states"``: float64 array of shape
          ``(BATCH_SIZE, 1, NUM_HEADS, head_dim)`` holding the last token of
          every sequence.
        * ``"prefill_slot_mapping"``: flat list of int slot indices for the
          prefill tokens, sequences concatenated in batch order.
        * ``"decoder_slot_mapping"``: list with one int slot index per
          sequence, for its last token.
        * ``"block_tables"``: float64 array of shape
          ``(BATCH_SIZE, max_blocks_per_seq + 10)``; unused entries are -1.

        where ``head_dim = HIDDEN_SIZE // NUM_HEADS``.
    """
    head_dim = HIDDEN_SIZE // NUM_HEADS

    # First random draw: hidden states for the whole (prefill + decode) sequence.
    hidden_states = np.random.normal(
        0, 0.01, (BATCH_SIZE, MAX_SEQ_LEN, NUM_HEADS, head_dim)).astype(np.float32)
    # np.zeros defaults to float64, so the two split buffers below are float64
    # even though the values copied into them originate from float32 data.
    prefill_hidden_states = np.zeros((BATCH_SIZE, PREFILL_SEQ_LEN, NUM_HEADS, head_dim))
    decoder_hidden_states = np.zeros((BATCH_SIZE, 1, NUM_HEADS, head_dim))
    # Actual number of tokens per sequence (prefill tokens plus one more token).
    q_seq_lens = np.array([PREFILL_SEQ_LEN + 1] * BATCH_SIZE, dtype=np.int32)

    # Second and third random draws: fused QKV weight, then output projection.
    qkv_weight = np.random.normal(
        0, 0.01, ((NUM_HEADS + 2 * n_kv_heads) * head_dim, HIDDEN_SIZE)).astype(np.float32)
    proj_weight = np.random.normal(
        0, 0.01, (HIDDEN_SIZE, NUM_HEADS * head_dim)).astype(np.float32)

    # Split every sequence: all tokens but the last go to prefill, the last
    # token goes to decode.
    for batch_idx, q_seq_len in enumerate(q_seq_lens):
        last_token = q_seq_len - 1
        prefill_hidden_states[batch_idx, 0:last_token, :, :] = hidden_states[batch_idx, 0:last_token, :, :]
        decoder_hidden_states[batch_idx, :, :, :] = hidden_states[batch_idx, last_token:q_seq_len, :, :]

    def _num_tokens_to_min_blocks(num_tokens: int, block_size: int) -> int:
        # Equals floor(num_tokens / block_size) + 1, i.e. one extra block is
        # reserved when num_tokens is an exact multiple of block_size.
        return (num_tokens + block_size) // block_size

    num_blocks_list = [
        _num_tokens_to_min_blocks(num_tokens, BLOCK_SIZE)
        for num_tokens in q_seq_lens
    ]
    block_table_pad_tokens = 10
    max_block_table_len = max(num_blocks_list) + block_table_pad_tokens
    # batch_size x max_num_blocks_per_seq, with -1 marking unused entries.
    block_tables = np.full((BATCH_SIZE, max_block_table_len), -1.0)

    # Flat list of slot indices, one per token, sequences in batch order.
    slot_mapping_list = []
    # Block ids are handed out from the highest id downwards; the highest id
    # equals the total number of blocks needed by all sequences.
    block_base_idx = sum(num_blocks_list)
    for seq_idx, num_tokens in enumerate(q_seq_lens):
        num_blocks = num_blocks_list[seq_idx]
        block_table = list(
            range(block_base_idx, block_base_idx - num_blocks, -1))
        block_tables[seq_idx, :len(block_table)] = block_table
        # slot = offset inside the block + first slot of the token's block.
        slot_mapping_list.extend(
            (token_idx % BLOCK_SIZE) + block_table[token_idx // BLOCK_SIZE] * BLOCK_SIZE
            for token_idx in range(num_tokens)
        )
        block_base_idx -= num_blocks

    # Split the flat slot list the same way as the hidden states.
    prefill_slot_mapping = []
    decoder_slot_mapping = []
    base_idx = 0
    for seq_len in q_seq_lens:
        last_slot = base_idx + seq_len - 1
        prefill_slot_mapping.extend(slot_mapping_list[base_idx:last_slot])
        decoder_slot_mapping.append(slot_mapping_list[last_slot])
        base_idx += seq_len

    return {
        "qkv_weight": qkv_weight,
        "proj_weight": proj_weight,
        "prefill_hidden_states": prefill_hidden_states,
        "decoder_hidden_states": decoder_hidden_states,
        "prefill_slot_mapping": prefill_slot_mapping,
        "decoder_slot_mapping": decoder_slot_mapping,
        "block_tables": block_tables
    }


def get_golden() -> dict[str, np.ndarray]:
    """Return the hard-coded golden (reference) outputs used by the test.

    The values are literal constants; nothing is computed here.

    Returns:
        A dict mapping ``"prefill_output_1"``, ``"decode_output_1"``,
        ``"prefill_output_2"`` and ``"decode_output_2"`` to float32 arrays.
        Each prefill array has 4 rows and each decode array has 2 rows, with
        64 values per row.

        NOTE(review): the ``_1`` / ``_2`` suffixes presumably correspond to two
        test configurations (e.g. different ``n_kv_heads``) - confirm against
        the test that consumes this data.
    """
    # Prefill reference, first configuration: 4 rows x 64 values.
    # NOTE(review): rows presumably are BATCH_SIZE * PREFILL_SEQ_LEN tokens and
    # columns HIDDEN_SIZE features - confirm.
    prefill_output_1 = np.array([[4.87042189e-06, -3.26830850e-05, 6.33480740e-05,
                                  -2.74635004e-05, -4.03012891e-05, -6.50920774e-05,
                                  -6.42207160e-05, 1.07843216e-05, 3.30012335e-05,
                                  -4.96758548e-05, -3.03016805e-05, 6.17170954e-05,
                                  -9.01401800e-05, 2.80463064e-05, 4.18824748e-05,
                                  -2.72499256e-06, 4.47683669e-05, 2.03576838e-05,
                                  -6.06367348e-06, -4.28043932e-05, 3.76185671e-05,
                                  7.46326896e-06, 6.21303334e-05, -5.02910807e-05,
                                  2.14941374e-05, 1.26251983e-04, 4.53199391e-05,
                                  -1.36593080e-05, 3.43329921e-05, 1.68794850e-06,
                                  -2.21433365e-05, -1.23757432e-04, -3.15084690e-05,
                                  -2.46682084e-05, 2.18231489e-05, 3.31677329e-05,
                                  1.69207158e-06, -4.76929490e-05, 2.47722219e-05,
                                  9.19669583e-06, -8.92527751e-05, 3.24163702e-05,
                                  6.66684646e-05, 9.47047738e-05, 6.21318031e-05,
                                  1.43903217e-05, -4.57681635e-05, -2.16236294e-05,
                                  6.05946889e-06, 1.16722586e-05, -1.78084138e-05,
                                  -8.97299615e-05, -2.61336900e-05, 3.46529523e-05,
                                  3.18849779e-05, -5.88439434e-05, 1.38231480e-04,
                                  -4.42174423e-05, 5.08043922e-05, -6.66681008e-05,
                                  1.38951900e-05, 6.80091616e-05, -1.23025084e-05,
                                  4.17552255e-05],
                                 [9.49401419e-06, 1.90739265e-05, 6.56960110e-05,
                                  -4.93599437e-05, -1.16793044e-05, -4.21300319e-05,
                                  -4.38965435e-05, -5.06620427e-06, 2.09389054e-05,
                                  -4.27646555e-05, -2.80807890e-05, 1.64731118e-05,
                                  -3.34221040e-05, -3.30901503e-06, 4.99596717e-05,
                                  -4.50394873e-05, -4.25584085e-06, 3.57373960e-06,
                                  5.81055701e-06, -5.89593412e-07, -1.13535070e-05,
                                  1.45581607e-05, 2.09581103e-05, 1.72940327e-05,
                                  3.68817819e-05, 6.53399038e-05, 2.05439155e-05,
                                  8.91967829e-06, 2.30790556e-05, -2.35129633e-07,
                                  9.37260756e-06, -4.27530904e-05, 1.73840272e-05,
                                  7.74195541e-06, 1.86754150e-05, 3.45440130e-05,
                                  -2.00022332e-05, 2.09872469e-05, -2.86882346e-06,
                                  -1.01048563e-05, -3.01084747e-05, 4.11823094e-05,
                                  -1.58757503e-05, 9.08889488e-05, -3.36722042e-06,
                                  1.83340580e-05, -3.04139212e-05, 7.45704665e-06,
                                  2.05170181e-06, 2.02014617e-05, -6.36234199e-06,
                                  -5.29401295e-05, 5.27058683e-06, 4.55776280e-05,
                                  2.43298073e-05, 9.46173259e-06, 6.61690356e-05,
                                  -1.65850361e-05, 3.61970488e-05, 4.87889565e-07,
                                  8.64537560e-06, 4.98160480e-05, -2.32789007e-05,
                                  1.78827886e-05],
                                 [1.43057827e-04, -4.44612087e-05, -6.49614158e-05,
                                  -6.25576868e-05, -3.83419974e-05, 3.51661656e-05,
                                  -1.29701410e-04, -1.62674623e-05, 1.01051664e-05,
                                  6.47966153e-05, -1.13867492e-04, 2.22215222e-05,
                                  -7.65149616e-06, 7.43268756e-05, -2.35968419e-05,
                                  -3.42666426e-05, -1.32323476e-05, 1.11024492e-05,
                                  -3.37448873e-05, 2.30112892e-05, -8.67997733e-05,
                                  -1.23952368e-05, 1.17417010e-04, -3.28528440e-05,
                                  -2.24592477e-05, 1.15961717e-04, 8.26371761e-05,
                                  -1.36068062e-04, -7.09364103e-05, 4.61931586e-06,
                                  -6.22197786e-06, 8.07744364e-05, 9.34720301e-05,
                                  -8.87700480e-06, 5.53884456e-05, -8.07494434e-05,
                                  9.44123258e-07, 6.28386188e-05, -1.03329230e-05,
                                  7.70002080e-05, 6.76722557e-05, 1.14403756e-05,
                                  -1.80115258e-05, 5.83545661e-05, 4.89570084e-05,
                                  3.79923840e-05, 5.14637259e-06, 7.28233354e-05,
                                  -4.85976489e-05, -2.66200313e-05, -1.34291950e-05,
                                  5.24254938e-05, 1.65319416e-05, -2.76208866e-05,
                                  5.62779951e-06, 2.23442403e-05, -3.98618904e-05,
                                  5.66873859e-05, 7.67959064e-05, 1.31192006e-04,
                                  -2.86334998e-05, 6.44897318e-06, 2.72060402e-06,
                                  2.52412556e-05],
                                 [2.84876023e-05, -1.10890032e-05, -4.49159343e-05,
                                  -2.89382333e-05, -3.91892827e-05, -5.80514825e-05,
                                  -1.04868253e-04, 1.91987074e-05, -6.32686715e-05,
                                  -1.49142352e-05, -5.75787017e-05, 1.24368262e-05,
                                  -9.98018641e-06, 1.65708170e-05, 4.37372146e-05,
                                  -1.28095098e-05, -5.31971637e-07, -2.86972841e-06,
                                  -2.47431617e-05, -1.55568375e-06, 1.33329777e-05,
                                  1.68853894e-05, 1.00783720e-04, -6.38273123e-05,
                                  1.79290646e-06, 9.35466524e-05, 7.62347099e-06,
                                  -8.44511742e-05, 2.71914519e-06, 1.38615260e-05,
                                  1.35487944e-05, 1.96406404e-06, 5.05116332e-05,
                                  1.73582084e-05, 3.80473321e-05, -7.77070672e-05,
                                  4.28038766e-05, 6.61049053e-05, 5.28695000e-06,
                                  5.00978967e-05, 2.46462969e-05, 3.44838736e-05,
                                  4.58245540e-06, 2.91285305e-05, 1.43549269e-05,
                                  5.54573489e-06, -2.01013936e-05, 3.24869106e-05,
                                  -4.04004713e-05, -1.28425299e-05, 1.98764683e-05,
                                  -2.33825381e-06, 1.68295846e-05, -2.95287355e-05,
                                  2.57042429e-05, 5.43877877e-06, -1.53028177e-05,
                                  6.70858717e-05, 1.42318995e-05, 3.92777001e-05,
                                  -3.93951959e-05, 3.47994828e-05, -5.00203023e-05,
                                  1.71665633e-05]], dtype=np.float32)

    # Decode reference, first configuration: 2 rows x 64 values.
    decode_output_1 = np.array([[2.94203510e-05, -5.36229663e-06, 6.34292228e-05,
                                 -4.66644524e-05, 4.42733881e-06, -3.07248501e-06,
                                 -1.83742977e-05, -6.86288513e-06, 5.28169548e-05,
                                 -1.22802803e-05, -2.86392697e-05, 4.59555049e-05,
                                 -1.52265511e-05, -7.92623268e-06, 3.48873582e-05,
                                 -1.66570462e-05, -3.10142859e-05, -1.76831945e-05,
                                 -5.71523060e-06, 2.60284323e-05, -8.90285264e-06,
                                 1.65236324e-05, 7.39054849e-06, -7.71275245e-06,
                                 1.36148901e-05, 7.76158486e-05, 3.93778246e-05,
                                 3.81484679e-05, 1.18642110e-05, 5.13427949e-06,
                                 3.47423847e-05, -3.68839355e-05, -1.88473350e-05,
                                 3.90182407e-07, -1.36598792e-05, 4.68672479e-05,
                                 -2.52590416e-05, -1.47496812e-05, 2.75770435e-06,
                                 -6.60356181e-06, -5.65422415e-05, 5.47635937e-05,
                                 -3.98120865e-05, 9.75557123e-05, 2.89859563e-05,
                                 1.23135133e-05, -1.35048767e-05, 4.40821668e-06,
                                 2.49140630e-05, 6.13136990e-06, -4.86342142e-05,
                                 -1.58069288e-05, 8.68530515e-06, 6.86945132e-05,
                                 3.59779624e-05, -9.61003479e-06, 6.73899776e-05,
                                 -1.17567988e-05, 4.41567063e-05, 3.76296271e-06,
                                 1.24967710e-05, 3.97767180e-05, -1.37137331e-05,
                                 3.59719743e-05],
                                [3.45780359e-06, -4.76866171e-06, -6.69746514e-05,
                                 1.08002005e-05, -4.54376423e-05, -4.78452894e-05,
                                 -5.67951138e-05, 4.37259114e-05, -9.72900234e-05,
                                 -6.07286020e-05, -5.10808241e-05, 3.79057128e-05,
                                 -4.72735564e-06, -1.70841358e-05, 5.99223786e-05,
                                 -1.63897002e-05, 2.24612199e-06, -1.38757132e-05,
                                 -3.33144126e-05, -3.19658757e-05, 5.42563112e-06,
                                 3.84533860e-06, 8.31207799e-05, -5.52753809e-05,
                                 -2.77277541e-05, 5.59701766e-05, -5.58598867e-06,
                                 -6.61516897e-05, 4.75820125e-05, 5.57046314e-06,
                                 1.75445020e-05, 2.73535752e-05, 5.84891313e-05,
                                 3.20350373e-05, 4.20762080e-05, -8.09027624e-05,
                                 8.35031606e-05, 5.61741144e-05, -8.05898253e-06,
                                 2.97976421e-05, 4.88479745e-05, 3.13622040e-05,
                                 2.71338249e-05, 2.85179049e-05, -1.71893644e-05,
                                 4.40456597e-06, -2.72962807e-05, 4.70874656e-05,
                                 -3.83619445e-05, -4.60210285e-05, 3.69368136e-05,
                                 -4.20523720e-05, 3.30058651e-06, -1.57899212e-05,
                                 6.12896838e-06, 2.32406674e-05, 1.43065490e-05,
                                 5.24259813e-05, 2.66047737e-05, 4.56154630e-05,
                                 -1.71753290e-05, 1.81363339e-05, -4.19770040e-05,
                                 -8.17446835e-06]], dtype=np.float32)

    # Prefill reference, second configuration: 4 rows x 64 values.
    prefill_output_2 = np.array([[-4.54976835e-05, -1.17711657e-04, -5.14050334e-05,
                                  -6.83533654e-05, -2.16778899e-05, -4.00271383e-05,
                                  -3.32936688e-05, 2.29605539e-05, 3.44702057e-05,
                                  -4.49593863e-05, -7.34170826e-05, 9.89834825e-06,
                                  -7.44706413e-05, 4.11930705e-06, 4.22717458e-05,
                                  -6.46424451e-05, -4.66139063e-05, 5.51207049e-05,
                                  -4.37949057e-05, -5.64865786e-06, 3.82605213e-05,
                                  -1.45379254e-05, 5.30062243e-06, -2.37652803e-05,
                                  3.01490581e-05, 1.54218360e-05, -4.63407669e-05,
                                  -8.16426473e-05, -5.41356385e-05, -9.77777563e-06,
                                  -9.50162212e-05, -1.20898614e-04, 8.72888922e-05,
                                  4.25108046e-05, -9.72098787e-05, -1.02172344e-04,
                                  8.73486351e-06, -9.62190006e-06, 3.54833101e-05,
                                  -8.09452758e-05, -7.55481014e-05, -3.01690143e-05,
                                  -2.80331751e-05, 1.18073049e-05, 2.16980679e-05,
                                  4.99941707e-05, -2.18116547e-05, -5.29777062e-05,
                                  3.74395240e-05, -7.09479646e-05, 3.39394464e-05,
                                  -1.07073021e-04, -4.79830378e-05, 8.63588230e-06,
                                  7.83802170e-06, -1.03024366e-04, 3.48897229e-05,
                                  1.23633319e-04, -6.16188190e-05, 7.06601713e-05,
                                  -2.78099978e-05, 3.75602212e-06, -3.35309051e-05,
                                  -5.43052047e-05],
                                 [3.97725780e-05, -3.07583796e-05, 9.17402122e-06,
                                  -2.24778705e-05, 2.49078439e-05, -1.03103594e-05,
                                  2.92933291e-05, -3.26207584e-07, 4.43769641e-05,
                                  -4.00253266e-05, -8.33276135e-05, -1.83020711e-06,
                                  -5.06792931e-05, 8.99141560e-06, 6.48916011e-06,
                                  1.73002388e-06, -6.12503618e-06, 9.25523727e-06,
                                  -1.90589913e-06, 1.09238517e-05, 1.51287659e-05,
                                  3.97535441e-05, -3.35462537e-05, 5.13602681e-05,
                                  1.92273164e-05, -2.63910842e-05, -5.11290455e-05,
                                  -2.35793414e-05, -5.65370028e-05, 1.62290053e-05,
                                  -7.98380715e-05, -1.11749017e-04, -9.43477698e-06,
                                  1.42966492e-05, 2.68718622e-06, -3.44342589e-05,
                                  -2.57137458e-06, -1.23400996e-05, 7.26287835e-05,
                                  -6.33822638e-05, -1.59797655e-05, -7.12747533e-06,
                                  -2.45621122e-05, 2.82465553e-05, 2.91974975e-05,
                                  2.16478966e-05, -4.22123958e-05, -2.42009137e-05,
                                  1.62434953e-05, -2.35578191e-05, -2.02179053e-05,
                                  -2.93987159e-05, -9.52924165e-06, -7.82179904e-06,
                                  1.24390590e-05, -1.01283304e-05, 2.17082688e-05,
                                  8.29088603e-05, -1.08077838e-05, 4.73675864e-05,
                                  -3.32549644e-05, 1.85873596e-05, -2.40674854e-05,
                                  -9.29001817e-06],
                                 [1.44197897e-04, 5.95039637e-05, -3.15447578e-05,
                                  6.01922802e-05, -4.61218115e-05, -5.27396915e-05,
                                  6.45090840e-05, 9.10419330e-05, 1.37268347e-04,
                                  2.16031185e-05, 3.32774798e-05, 2.55908617e-05,
                                  -1.53572437e-05, -1.56326405e-05, -2.14551710e-06,
                                  1.00341204e-04, 1.02503895e-04, -1.85255522e-05,
                                  5.71625642e-05, 6.51904993e-05, 2.25564145e-05,
                                  1.16208183e-04, -1.38919604e-05, 1.66132071e-04,
                                  -1.33664842e-04, -1.07256652e-04, -6.56931006e-05,
                                  5.96554164e-05, 6.73388713e-05, -7.57015077e-05,
                                  7.31835025e-05, -2.37249515e-06, -5.08724515e-05,
                                  5.18189190e-05, 2.07554418e-04, 5.08404355e-06,
                                  -2.96790204e-05, -6.71617381e-05, 5.01858340e-05,
                                  5.76945167e-05, 5.63641624e-05, 4.66814054e-05,
                                  3.99280179e-05, 5.06890865e-05, 8.38160631e-05,
                                  -7.19475502e-05, 1.47599485e-05, 6.59490397e-05,
                                  1.74093177e-04, 3.98902266e-05, -2.00636758e-04,
                                  6.24412423e-05, -1.45305348e-05, -6.45982655e-05,
                                  2.50495468e-05, 9.82189613e-06, 1.30722969e-04,
                                  6.29705610e-05, 7.73973079e-06, 5.61842899e-05,
                                  1.85940698e-05, 7.40375835e-05, 7.41065305e-05,
                                  -3.58311518e-05],
                                 [4.60205229e-05, 6.62673338e-05, -1.05460085e-05,
                                  2.38315624e-05, 1.44803907e-05, -1.57447050e-06,
                                  3.48076610e-05, 4.78515940e-05, 5.77544233e-05,
                                  2.03429900e-05, 2.03695708e-05, -3.22352457e-06,
                                  5.56607702e-05, -1.25271999e-05, -2.72190064e-05,
                                  4.60159572e-05, 5.42461421e-05, -5.81150198e-06,
                                  7.23578851e-05, 3.66961285e-05, -2.08504343e-06,
                                  9.68217428e-05, -3.38161663e-05, 7.50513136e-05,
                                  -1.11466892e-04, -2.84176040e-05, -2.30377918e-05,
                                  3.29688883e-05, 5.42836133e-05, -3.85211533e-05,
                                  3.45974586e-05, -3.10430914e-05, -5.78143954e-05,
                                  -1.85637145e-06, 1.04568469e-04, 4.15095474e-06,
                                  -3.62006176e-05, -5.54482576e-05, 1.10853216e-05,
                                  8.11286463e-06, 3.66888453e-05, 1.56402875e-05,
                                  -1.32053319e-05, -1.45510157e-05, 2.36724009e-05,
                                  -2.66628049e-05, 1.75008281e-05, 6.84956758e-05,
                                  5.75353151e-05, 5.51972480e-05, -7.29626045e-05,
                                  5.23972412e-05, 1.27463491e-05, -2.94074798e-05,
                                  4.70455670e-05, 6.44659667e-05, 5.14301209e-05,
                                  8.51124241e-06, -1.04915489e-05, 6.81683832e-06,
                                  -2.22562521e-05, 6.14828168e-05, 2.49914829e-05,
                                  3.22499545e-05]], dtype=np.float32)

    # Decode reference, second configuration: 2 rows x 64 values.
    decode_output_2 = np.array([[4.6741250e-05, -3.0301186e-05, -2.7919381e-05, -2.4746618e-05,
                                 2.1522601e-05, -4.5019431e-05, 8.7622002e-06, 2.0723572e-05,
                                 5.5089880e-05, -5.5801909e-05, -3.9258124e-05, -2.1380823e-05,
                                 -2.9022094e-05, 2.4279749e-05, 2.1042726e-05, 9.0319363e-06,
                                 3.9009123e-05, 1.0439881e-05, 1.3547102e-05, 8.2676543e-06,
                                 8.1436719e-06, 1.3009090e-05, -1.5163456e-05, 4.9737788e-05,
                                 -7.1281365e-06, -5.1747134e-05, -2.2586337e-05, -8.3717650e-06,
                                 -4.0486979e-05, -5.8811847e-06, -4.5481494e-05, -7.5779761e-05,
                                 2.2921333e-05, 2.2264219e-05, 1.7241706e-05, -2.9609406e-05,
                                 -3.3190408e-06, -4.6541991e-05, 3.4206059e-05, -3.0471780e-05,
                                 1.5180687e-05, -5.3317008e-06, -4.5742654e-06, 3.1106010e-05,
                                 1.3307816e-05, -7.5097601e-06, -3.7082562e-05, -4.8276638e-06,
                                 5.4086227e-05, -1.2385093e-05, -3.8333310e-05, -2.0047819e-05,
                                 -2.9704754e-05, 5.8301021e-06, -1.3417582e-05, -1.9183804e-05,
                                 3.4972520e-05, 6.1324092e-05, -1.5627235e-07, 1.8154291e-05,
                                 -5.4318944e-06, 3.0106707e-05, 1.0881942e-05, -1.6869601e-05],
                                [3.2493161e-05, 5.5838715e-05, 2.5370304e-05, 1.2004887e-05,
                                 4.8812111e-05, 2.2970677e-05, 3.5203499e-05, -5.6148806e-06,
                                 -2.6832142e-06, -2.0612429e-06, -3.7373024e-06, -4.1473268e-05,
                                 3.9182829e-05, 6.3850844e-06, -3.2081891e-05, 2.6401583e-05,
                                 2.5973635e-05, -2.2938284e-05, 3.9311919e-05, 1.6188920e-05,
                                 -6.2334025e-06, 5.1802122e-05, -5.1651391e-06, 4.9668088e-05,
                                 -4.0603372e-05, -2.5189818e-05, 2.2700804e-05, 2.1606340e-05,
                                 2.7262146e-05, 4.1041059e-07, 8.0513394e-07, -1.9021283e-05,
                                 -6.2119332e-05, -2.3868377e-05, 8.5665255e-05, 3.4490058e-05,
                                 -2.1369886e-05, -4.2073083e-05, -2.7748392e-05, 1.6611502e-05,
                                 4.4805649e-05, -6.0247038e-07, -1.1755194e-05, -1.6656131e-05,
                                 1.9337034e-05, -3.8748112e-05, 1.2256173e-05, 4.2489079e-05,
                                 9.2310611e-06, 1.2447832e-05, -4.0743870e-07, 2.9845280e-05,
                                 5.9647459e-06, -1.1102197e-06, 7.1130940e-05, 5.5083296e-05,
                                 1.0366221e-05, 9.9901772e-06, 1.9193949e-05, 1.5707485e-06,
                                 -6.2833083e-06, 1.9374889e-05, -9.2856499e-06, 2.9506509e-05]], dtype=np.float32)

    return {
        "prefill_output_1": prefill_output_1,
        "decode_output_1": decode_output_1,
        "prefill_output_2": prefill_output_2,
        "decode_output_2": decode_output_2,
    }


def get_gpu_data() -> dict[str, np.ndarray]:
    """Return the hard-coded gpu reference outputs used by the test.

    The mapping holds four float16 arrays keyed "prefill_output_1",
    "decode_output_1", "prefill_output_2" and "decode_output_2". Each row
    carries 64 values; the prefill entries have four rows and the decode
    entries have two rows.
    """
    gpu_outputs = {}

    gpu_outputs["prefill_output_1"] = np.array([
        [4.7088e-06, -3.2663e-05, 6.3896e-05, -2.7537e-05, -4.0531e-05,
         -6.4850e-05, -6.4373e-05, 1.0729e-05, 3.2425e-05, -4.9829e-05,
         -3.0160e-05, 6.1512e-05, -9.0599e-05, 2.7895e-05, 4.1723e-05,
         -2.5034e-06, 4.5061e-05, 1.9789e-05, -5.9009e-06, -4.2915e-05,
         3.7432e-05, 7.3910e-06, 6.1989e-05, -5.0068e-05, 2.1577e-05,
         1.2589e-04, 4.5538e-05, -1.3888e-05, 3.4332e-05, 1.7285e-06,
         -2.2173e-05, -1.2398e-04, -3.1471e-05, -2.4676e-05, 2.1577e-05,
         3.3379e-05, 1.9073e-06, -4.7445e-05, 2.4557e-05, 9.2983e-06,
         -8.9169e-05, 3.2663e-05, 6.6757e-05, 9.4891e-05, 6.1989e-05,
         1.4603e-05, -4.6015e-05, -2.1219e-05, 5.9605e-06, 1.1623e-05,
         -1.7524e-05, -8.9645e-05, -2.6226e-05, 3.4571e-05, 3.1710e-05,
         -5.9128e-05, 1.3828e-04, -4.4346e-05, 5.1022e-05, -6.6280e-05,
         1.3769e-05, 6.8188e-05, -1.2398e-05, 4.1723e-05],
        [9.4175e-06, 1.9073e-05, 6.6280e-05, -4.9353e-05, -1.1802e-05,
         -4.1962e-05, -4.4107e-05, -5.2452e-06, 2.0623e-05, -4.2677e-05,
         -2.8133e-05, 1.6451e-05, -3.3617e-05, -3.3975e-06, 5.0068e-05,
         -4.4823e-05, -4.0531e-06, 3.3379e-06, 5.9605e-06, -5.3644e-07,
         -1.1504e-05, 1.4663e-05, 2.0862e-05, 1.7524e-05, 3.6955e-05,
         6.5327e-05, 2.0623e-05, 8.8215e-06, 2.3007e-05, -1.7881e-07,
         9.4771e-06, -4.2915e-05, 1.7524e-05, 7.7486e-06, 1.8597e-05,
         3.4571e-05, -2.0027e-05, 2.1100e-05, -2.8610e-06, -1.0192e-05,
         -3.0279e-05, 4.1246e-05, -1.5855e-05, 9.1076e-05, -3.4571e-06,
         1.8358e-05, -3.0518e-05, 7.7486e-06, 2.0266e-06, 2.0385e-05,
         -6.3181e-06, -5.2691e-05, 5.2452e-06, 4.5538e-05, 2.4438e-05,
         9.4771e-06, 6.6280e-05, -1.6570e-05, 3.6240e-05, 7.7486e-07,
         8.4043e-06, 4.9829e-05, -2.3127e-05, 1.7643e-05],
        [1.4305e-04, -4.4584e-05, -6.4850e-05, -6.2466e-05, -3.8385e-05,
         3.4809e-05, -1.2970e-04, -1.6093e-05, 1.0133e-05, 6.4850e-05,
         -1.1396e-04, 2.2411e-05, -7.5102e-06, 7.4387e-05, -2.3484e-05,
         -3.4094e-05, -1.3053e-05, 1.1146e-05, -3.3617e-05, 2.2888e-05,
         -8.6308e-05, -1.2398e-05, 1.1730e-04, -3.2663e-05, -2.2531e-05,
         1.1635e-04, 8.2493e-05, -1.3638e-04, -7.1049e-05, 4.5300e-06,
         -6.4373e-06, 8.0585e-05, 9.3460e-05, -8.5235e-06, 5.5552e-05,
         -8.1062e-05, 1.0729e-06, 6.2466e-05, -1.0490e-05, 7.7248e-05,
         6.7234e-05, 1.1504e-05, -1.7762e-05, 5.8651e-05, 4.9114e-05,
         3.8147e-05, 5.1260e-06, 7.2956e-05, -4.8876e-05, -2.6584e-05,
         -1.3351e-05, 5.2452e-05, 1.6451e-05, -2.7776e-05, 5.8413e-06,
         2.2411e-05, -3.9577e-05, 5.6505e-05, 7.6771e-05, 1.3161e-04,
         -2.8849e-05, 6.4969e-06, 2.8610e-06, 2.5272e-05],
        [2.8372e-05, -1.1027e-05, -4.4823e-05, -2.8610e-05, -3.9101e-05,
         -5.8413e-05, -1.0538e-04, 1.9431e-05, -6.3419e-05, -1.5020e-05,
         -5.7459e-05, 1.2279e-05, -9.9540e-06, 1.6332e-05, 4.3869e-05,
         -1.2815e-05, -2.9802e-07, -2.9206e-06, -2.4676e-05, -1.6689e-06,
         1.3530e-05, 1.6928e-05, 1.0109e-04, -6.3896e-05, 1.9073e-06,
         9.3937e-05, 7.3910e-06, -8.4400e-05, 2.8610e-06, 1.3888e-05,
         1.3649e-05, 1.9073e-06, 5.0545e-05, 1.7524e-05, 3.8147e-05,
         -7.7724e-05, 4.2915e-05, 6.6280e-05, 5.3644e-06, 5.0306e-05,
         2.4438e-05, 3.4571e-05, 5.0068e-06, 2.9206e-05, 1.4305e-05,
         5.5432e-06, -2.0146e-05, 3.2663e-05, -4.0531e-05, -1.2636e-05,
         2.0146e-05, -2.3246e-06, 1.6809e-05, -2.9683e-05, 2.5749e-05,
         5.4836e-06, -1.5259e-05, 6.7234e-05, 1.4246e-05, 3.9339e-05,
         -3.9577e-05, 3.5048e-05, -4.9829e-05, 1.7166e-05],
    ], dtype=np.float16)

    gpu_outputs["decode_output_1"] = np.array([
        [2.933e-05, -5.364e-06, 6.390e-05, -4.673e-05, 4.232e-06,
         -2.682e-06, -1.848e-05, -6.974e-06, 5.293e-05, -1.210e-05,
         -2.861e-05, 4.625e-05, -1.550e-05, -8.166e-06, 3.481e-05,
         -1.645e-05, -3.076e-05, -1.800e-05, -5.484e-06, 2.611e-05,
         -8.941e-06, 1.681e-05, 7.331e-06, -7.391e-06, 1.365e-05,
         7.772e-05, 3.958e-05, 3.815e-05, 1.186e-05, 5.245e-06,
         3.481e-05, -3.719e-05, -1.895e-05, 1.192e-07, -1.371e-05,
         4.697e-05, -2.539e-05, -1.466e-05, 2.623e-06, -6.795e-06,
         -5.698e-05, 5.507e-05, -3.982e-05, 9.775e-05, 2.921e-05,
         1.240e-05, -1.353e-05, 4.768e-06, 2.503e-05, 6.437e-06,
         -4.864e-05, -1.585e-05, 8.762e-06, 6.914e-05, 3.600e-05,
         -9.656e-06, 6.771e-05, -1.186e-05, 4.411e-05, 3.934e-06,
         1.228e-05, 3.982e-05, -1.365e-05, 3.576e-05],
        [3.517e-06, -5.007e-06, -6.723e-05, 1.085e-05, -4.554e-05,
         -4.816e-05, -5.698e-05, 4.411e-05, -9.775e-05, -6.080e-05,
         -5.126e-05, 3.791e-05, -4.649e-06, -1.705e-05, 6.032e-05,
         -1.645e-05, 2.205e-06, -1.395e-05, -3.338e-05, -3.219e-05,
         5.603e-06, 3.815e-06, 8.345e-05, -5.555e-05, -2.778e-05,
         5.627e-05, -5.722e-06, -6.628e-05, 4.768e-05, 5.603e-06,
         1.776e-05, 2.754e-05, 5.889e-05, 3.242e-05, 4.244e-05,
         -8.106e-05, 8.392e-05, 5.627e-05, -7.987e-06, 3.016e-05,
         4.911e-05, 3.147e-05, 2.742e-05, 2.861e-05, -1.729e-05,
         4.530e-06, -2.730e-05, 4.721e-05, -3.862e-05, -4.625e-05,
         3.719e-05, -4.220e-05, 3.278e-06, -1.609e-05, 6.080e-06,
         2.325e-05, 1.454e-05, 5.245e-05, 2.670e-05, 4.578e-05,
         -1.729e-05, 1.824e-05, -4.196e-05, -7.927e-06],
    ], dtype=np.float16)

    gpu_outputs["prefill_output_2"] = np.array([
        [-4.5300e-05, -1.1826e-04, -5.1022e-05, -6.8665e-05, -2.1458e-05,
         -3.9816e-05, -3.3140e-05, 2.2769e-05, 3.4332e-05, -4.5300e-05,
         -7.3433e-05, 9.8944e-06, -7.4863e-05, 4.0531e-06, 4.2200e-05,
         -6.4373e-05, -4.6253e-05, 5.5075e-05, -4.3392e-05, -5.6028e-06,
         3.8624e-05, -1.4305e-05, 5.3048e-06, -2.3603e-05, 3.0160e-05,
         1.5259e-05, -4.6968e-05, -8.2016e-05, -5.4359e-05, -9.8348e-06,
         -9.5367e-05, -1.2112e-04, 8.7261e-05, 4.2677e-05, -9.7275e-05,
         -1.0204e-04, 9.0003e-06, -9.5963e-06, 3.5524e-05, -8.0585e-05,
         -7.5340e-05, -3.0279e-05, -2.8253e-05, 1.1683e-05, 2.1935e-05,
         5.0068e-05, -2.1815e-05, -5.2929e-05, 3.7909e-05, -7.1049e-05,
         3.3855e-05, -1.0729e-04, -4.7922e-05, 9.0599e-06, 7.9870e-06,
         -1.0300e-04, 3.4809e-05, 1.2398e-04, -6.1989e-05, 7.0572e-05,
         -2.7776e-05, 4.2915e-06, -3.3617e-05, -5.4359e-05],
        [3.9816e-05, -3.0994e-05, 9.4771e-06, -2.2411e-05, 2.5034e-05,
         -1.0133e-05, 2.9445e-05, -2.9802e-07, 4.4584e-05, -4.0293e-05,
         -8.3447e-05, -1.7285e-06, -5.1022e-05, 9.0003e-06, 6.5565e-06,
         1.8477e-06, -5.9605e-06, 9.3579e-06, -1.6689e-06, 1.1027e-05,
         1.5259e-05, 3.9816e-05, -3.3617e-05, 5.1498e-05, 1.9312e-05,
         -2.6345e-05, -5.1498e-05, -2.3842e-05, -5.6744e-05, 1.6212e-05,
         -8.0109e-05, -1.1206e-04, -9.5963e-06, 1.4544e-05, 2.8014e-06,
         -3.4571e-05, -2.4438e-06, -1.2338e-05, 7.2956e-05, -6.3419e-05,
         -1.6093e-05, -7.1526e-06, -2.4796e-05, 2.8372e-05, 2.9445e-05,
         2.1577e-05, -4.2200e-05, -2.4199e-05, 1.6689e-05, -2.3484e-05,
         -2.0385e-05, -2.9564e-05, -9.4771e-06, -7.5698e-06, 1.2636e-05,
         -1.0192e-05, 2.1815e-05, 8.2970e-05, -1.1086e-05, 4.7445e-05,
         -3.3379e-05, 1.8954e-05, -2.4199e-05, -9.3579e-06],
        [1.4400e-04, 5.9366e-05, -3.1471e-05, 5.9843e-05, -4.6253e-05,
         -5.2929e-05, 6.4373e-05, 9.1076e-05, 1.3733e-04, 2.1577e-05,
         3.3379e-05, 2.5749e-05, -1.5378e-05, -1.5497e-05, -2.0266e-06,
         1.0061e-04, 1.0252e-04, -1.8358e-05, 5.7220e-05, 6.5327e-05,
         2.2531e-05, 1.1587e-04, -1.4126e-05, 1.6594e-04, -1.3351e-04,
         -1.0729e-04, -6.5327e-05, 5.9605e-05, 6.7234e-05, -7.5817e-05,
         7.2956e-05, -2.1458e-06, -5.0783e-05, 5.1498e-05, 2.0790e-04,
         5.0068e-06, -2.9802e-05, -6.7234e-05, 5.0068e-05, 5.7697e-05,
         5.6267e-05, 4.6730e-05, 3.9816e-05, 5.0783e-05, 8.3447e-05,
         -7.2002e-05, 1.4722e-05, 6.6280e-05, 1.7452e-04, 3.9816e-05,
         -2.0027e-04, 6.1989e-05, -1.4424e-05, -6.4373e-05, 2.4796e-05,
         9.8944e-06, 1.3065e-04, 6.2943e-05, 8.1658e-06, 5.5790e-05,
         1.8716e-05, 7.4387e-05, 7.3910e-05, -3.6240e-05],
        [4.6015e-05, 6.6280e-05, -1.0610e-05, 2.3603e-05, 1.4186e-05,
         -1.6093e-06, 3.4809e-05, 4.7922e-05, 5.7936e-05, 2.0385e-05,
         2.0504e-05, -3.2783e-06, 5.5552e-05, -1.2517e-05, -2.7180e-05,
         4.6015e-05, 5.4359e-05, -5.6624e-06, 7.2002e-05, 3.6716e-05,
         -2.0266e-06, 9.6798e-05, -3.4094e-05, 7.4863e-05, -1.1110e-04,
         -2.8372e-05, -2.3246e-05, 3.2902e-05, 5.4359e-05, -3.8624e-05,
         3.4332e-05, -3.0994e-05, -5.7697e-05, -1.9073e-06, 1.0443e-04,
         4.0531e-06, -3.6001e-05, -5.5313e-05, 1.1206e-05, 8.0466e-06,
         3.6478e-05, 1.5855e-05, -1.3292e-05, -1.4603e-05, 2.3603e-05,
         -2.6703e-05, 1.7405e-05, 6.8665e-05, 5.7697e-05, 5.5075e-05,
         -7.2956e-05, 5.1975e-05, 1.2696e-05, -2.9445e-05, 4.6730e-05,
         6.4373e-05, 5.1498e-05, 8.5831e-06, -1.0312e-05, 6.6161e-06,
         -2.2173e-05, 6.1512e-05, 2.5034e-05, 3.1948e-05],
    ], dtype=np.float16)

    gpu_outputs["decode_output_2"] = np.array([
        [4.673e-05, -3.028e-05, -2.778e-05, -2.480e-05, 2.158e-05,
         -4.506e-05, 8.941e-06, 2.086e-05, 5.531e-05, -5.603e-05,
         -3.934e-05, -2.134e-05, -2.933e-05, 2.432e-05, 2.098e-05,
         9.120e-06, 3.910e-05, 1.055e-05, 1.377e-05, 8.285e-06,
         8.345e-06, 1.329e-05, -1.526e-05, 5.007e-05, -7.153e-06,
         -5.198e-05, -2.289e-05, -8.464e-06, -4.077e-05, -5.960e-06,
         -4.578e-05, -7.629e-05, 2.277e-05, 2.241e-05, 1.740e-05,
         -2.980e-05, -3.278e-06, -4.673e-05, 3.433e-05, -3.040e-05,
         1.526e-05, -5.364e-06, -4.888e-06, 3.099e-05, 1.353e-05,
         -7.629e-06, -3.719e-05, -4.768e-06, 5.460e-05, -1.234e-05,
         -3.862e-05, -2.027e-05, -2.968e-05, 6.080e-06, -1.323e-05,
         -1.931e-05, 3.529e-05, 6.151e-05, -3.576e-07, 1.824e-05,
         -5.484e-06, 3.052e-05, 1.079e-05, -1.693e-05],
        [3.242e-05, 5.603e-05, 2.539e-05, 1.198e-05, 4.864e-05,
         2.301e-05, 3.552e-05, -5.364e-06, -2.444e-06, -2.146e-06,
         -3.576e-06, -4.125e-05, 3.910e-05, 6.378e-06, -3.195e-05,
         2.670e-05, 2.587e-05, -2.265e-05, 3.910e-05, 1.609e-05,
         -6.318e-06, 5.174e-05, -5.126e-06, 4.959e-05, -4.053e-05,
         -2.527e-05, 2.265e-05, 2.182e-05, 2.730e-05, 3.576e-07,
         7.749e-07, -1.895e-05, -6.199e-05, -2.384e-05, 8.535e-05,
         3.457e-05, -2.146e-05, -4.172e-05, -2.766e-05, 1.657e-05,
         4.506e-05, -2.980e-07, -1.186e-05, -1.669e-05, 1.931e-05,
         -3.862e-05, 1.228e-05, 4.244e-05, 9.179e-06, 1.258e-05,
         -4.768e-07, 2.968e-05, 5.960e-06, -1.252e-06, 7.105e-05,
         5.507e-05, 1.037e-05, 9.894e-06, 1.943e-05, 1.431e-06,
         -6.437e-06, 1.943e-05, -9.298e-06, 2.933e-05],
    ], dtype=np.float16)

    return gpu_outputs


# Both reference tables are built once, when this module is imported.
GOLDEN_DATA, GPU_DATA = get_golden(), get_gpu_data()
